[ROCm] Use fine-grain fence in reduction #2553
Conversation
* The global reduction path in the reduction kernel currently has two `threadfence` operations.
* The first `threadfence` is executed by all threads in all blocks, whereas the second is run only by the threads of a single block.
* On AMD GPUs, `threadfence` is a heavyweight operation, especially when run by every thread in the system (due to cross-XCD synchronization).
* Using a fine-grained fence therefore gives a significant performance boost on AMD GPUs.
* We issue a release fence when threads write to the reduce buffer in global memory, and an acquire fence when threads read from the reduce buffer (a hedged sketch of this pattern follows the reproducer below).

Co-authors: @amd-hhashemi, @jeffdaily

**Reproducer**:

```python
import torch

shapes = [(2, 896, 59, 91), ]
dims = [(2, 3), ]

for i, shape in enumerate(shapes):
    x = torch.randn(shape, device='cuda', dtype=torch.bfloat16)
    x = x.to(memory_format=torch.channels_last)

    # Warm-up iterations so the timed loop excludes startup overhead
    for _ in range(20):
        _ = torch.sum(x, dims[i], keepdim=True, dtype=torch.bfloat16)
    torch.cuda.synchronize()

    # Time 100 iterations with CUDA events; elapsed_time() returns ms,
    # so divide by 100 iterations and multiply by 1e3 to report microseconds.
    start_evt = torch.cuda.Event(enable_timing=True)
    end_evt = torch.cuda.Event(enable_timing=True)
    start_evt.record()
    for _ in range(100):
        _ = torch.sum(x, dims[i], keepdim=True, dtype=torch.bfloat16)
    end_evt.record()
    torch.cuda.synchronize()
    print(f"Avg time for shape {shape}: {start_evt.elapsed_time(end_evt) / 100 * 1e3:.2f} us")
```
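The bullets above describe replacing the two device-wide `__threadfence()` calls with a release fence after the staging write and an acquire fence before the final read. Below is a minimal, hypothetical HIP sketch of that acquire/release pattern, not the actual `Reduce.cuh` change: the kernel, buffer, and helper names (`reduce_sum`, `partial_sums`, `ticket`, `block_reduce_sum`) are invented for illustration, and it assumes Clang's `__builtin_amdgcn_fence` builtin with `"agent"` (single-GPU) scope; the real PR may use different primitives or scopes.

```cpp
// Hypothetical sketch (hipcc, AMD targets only): two-pass "last block" sum
// reduction using release/acquire fences instead of full __threadfence() calls.
#include <hip/hip_runtime.h>
#include <cstdio>
#include <vector>

__device__ float block_reduce_sum(float v) {
  __shared__ float smem[256];                 // assumes blockDim.x == 256
  smem[threadIdx.x] = v;
  __syncthreads();
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
    if (threadIdx.x < s) smem[threadIdx.x] += smem[threadIdx.x + s];
    __syncthreads();
  }
  return smem[0];
}

__global__ void reduce_sum(const float* in, float* partial_sums,
                           unsigned int* ticket, float* out, int n) {
  // Pass 1: each block reduces a grid-stride slice and stages one partial sum.
  float v = 0.f;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x)
    v += in[i];
  v = block_reduce_sum(v);

  __shared__ bool is_last;
  if (threadIdx.x == 0) {
    partial_sums[blockIdx.x] = v;
    // Release fence (run by every block, but lighter than __threadfence()):
    // publish this block's partial before taking a ticket.
    __builtin_amdgcn_fence(__ATOMIC_RELEASE, "agent");
    is_last = (atomicAdd(ticket, 1u) == gridDim.x - 1);
  }
  __syncthreads();

  // Pass 2: only the last block to finish accumulates the staged partials.
  if (is_last) {
    // Acquire fence, executed only by this one block, before reading the
    // partials written (and released) by every other block.
    __builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "agent");
    float s = 0.f;
    for (int b = threadIdx.x; b < gridDim.x; b += blockDim.x)
      s += partial_sums[b];
    s = block_reduce_sum(s);
    if (threadIdx.x == 0) *out = s;
  }
}

int main() {
  const int n = 1 << 20, threads = 256, blocks = 120;
  float *in, *partials, *out;
  unsigned int* ticket;
  hipMalloc(&in, n * sizeof(float));
  hipMalloc(&partials, blocks * sizeof(float));
  hipMalloc(&out, sizeof(float));
  hipMalloc(&ticket, sizeof(unsigned int));
  hipMemset(ticket, 0, sizeof(unsigned int));

  std::vector<float> host(n, 1.0f);           // expected sum == n
  hipMemcpy(in, host.data(), n * sizeof(float), hipMemcpyHostToDevice);

  reduce_sum<<<blocks, threads>>>(in, partials, ticket, out, n);

  float result = 0.f;
  hipMemcpy(&result, out, sizeof(float), hipMemcpyDeviceToHost);
  printf("sum = %.0f (expected %d)\n", result, n);
  return 0;
}
```

The key point mirrors the bullets above: every block pays only for a release fence on its own staging write, while the acquire side is confined to the single block that performs the final accumulation.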
Results (MI300X):
Jenkins build for commit baddc98b5389ba858f9677a5a2738914e429192d is in progress.
! cherry-pick --onto release/2.8 rocm7.1_internal_testing
Created branch autogenerated/release/2.8_cherry-pick_pr-2553 and #2560.
Created branch autogenerated/rocm7.1_internal_testing_cherry-pick_pr-2553 and #2562.
Created branch autogenerated/release/2.8_cherry-pick_pr-2553 and #2561.
Created branch autogenerated/rocm7.1_internal_testing_cherry-pick_pr-2553 and #2563.
[ROCm] Use fine-grain fence in reduction (#2561): Cherry-pick of #2553. Co-authored-by: Jerry Mannil <[email protected]>
[ROCm] Use fine-grain fence in reduction (#2563): Cherry-pick of #2553. Co-authored-by: Jerry Mannil <[email protected]>
cherry-pick of pytorch#160979
Less-performant fix until pytorch#161180 is finalized
Fixes SWDEV-545710
Cherry-picked to release/2.8 branch via #2561
Cherry-picked to rocm7.1_internal_testing branch via #2563